Your Name: Paloma Arellano
For this assignment, you will do an ablation study on the DCGAN model discussed in class and implement WGAN with weight clipping and (optionally) WGAN with gradient penalty.
An ablation study measures performance changes after changing certain components in the AI system. The goal is to understand the contribution of each component to the overall system.
Here is a copy of the code implementation from the course website. Please run the code to obtain a baseline result, then compare it against the results of the following ablation tasks.
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.cuda.manual_seed(manualSeed)
torch.backends.cudnn.deterministic = True
# BUG FIX: the attribute is `benchmark` (singular). The original line set
# `torch.backends.cudnn.benchmarks`, which silently created a new, unused
# attribute and left cuDNN benchmark mode at its default.
torch.backends.cudnn.benchmark = False
# NOTE(review): PYTHONHASHSEED only affects hash randomization if it is set
# before the interpreter starts; kept here to match the notebook's intent.
os.environ['PYTHONHASHSEED'] = str(manualSeed)
# Root directory for dataset
# dataroot = "data/celeba"
# Number of workers for dataloader
workers = 1
# Batch size during training
batch_size = 128
# Spatial size of training images. All images will be resized to this
# size using a transformer.
#image_size = 64
image_size = 32
# Number of channels in the training images. For color images this is 3
# (MNIST is grayscale, hence 1)
#nc = 3
nc = 1
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
#ngf = 64
ngf = 8
# Size of feature maps in discriminator
#ndf = 64
ndf = 8
# Number of training epochs
num_epochs = 5
# presumably for the WGAN experiments later in the assignment — TODO confirm
num_epochs_wgan = 15
num_iters = 250
# Learning rate for optimizers
lr = 0.0002
# presumably the RMSprop learning rate for WGAN training — TODO confirm
lr_rms = 5e-4
# Beta1 hyperparam for Adam optimizers
beta1 = 0.5
# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1
# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
# Initialize BCELoss function
criterion = nn.BCELoss()
# Create batch of latent vectors that we will use to visualize
# the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
# Establish convention for real and fake labels during training
real_label = 1.0
fake_label = 0.0
# Several useful functions
def initialize_net(net_class, init_method, device, ngpu):
    """Build, place, and initialize a network.

    Instantiates ``net_class(ngpu)``, moves it to ``device``, wraps it in
    ``DataParallel`` when several CUDA GPUs are requested, and applies the
    optional weight initializer before printing and returning the model.
    """
    net = net_class(ngpu).to(device)
    # Multi-GPU wrapping only makes sense on CUDA with more than one device.
    if device.type == 'cuda' and ngpu > 1:
        net = nn.DataParallel(net, list(range(ngpu)))
    # Apply the custom initializer (e.g. weights_init: mean=0, stdev=0.02).
    if init_method is not None:
        net.apply(init_method)
    # Echo the architecture so it appears in the notebook output.
    print(net)
    return net
def plot_GAN_loss(losses, labels):
    """Plot one training-loss curve per (loss, label) pair on a shared figure."""
    plt.figure(figsize=(10,5))
    plt.title("Losses During Training")
    for curve, name in zip(losses, labels):
        plt.plot(curve, label=f"{name}")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()
def plot_real_fake_images(real_batch, fake_batch):
    """Show up to 64 real images side by side with the latest generated grid.

    ``real_batch`` is a dataloader batch (images at index 0); ``fake_batch``
    is a list of pre-built image grids, of which the last entry is shown.
    """
    plt.figure(figsize=(15,15))
    # Left panel: grid of real training images.
    plt.subplot(1,2,1)
    plt.axis("off")
    plt.title("Real Images")
    real_grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True)
    plt.imshow(np.transpose(real_grid.cpu(), (1,2,0)))
    # Right panel: most recent snapshot of generator output.
    plt.subplot(1,2,2)
    plt.axis("off")
    plt.title("Fake Images")
    plt.imshow(np.transpose(fake_batch[-1], (1,2,0)))
    plt.show()
# custom weights initialization called on netG and netD
def weights_init(m):
    """DCGAN weight initializer, meant to be applied via ``net.apply(weights_init)``.

    Conv / ConvTranspose weights are drawn from N(0, 0.02); BatchNorm weights
    from N(1, 0.02) with biases zeroed. Other module types are left untouched.
    """
    layer_type = m.__class__.__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
# Download the MNIST dataset
dataset = dset.MNIST(
    'data', train=True, download=True,
    transform=transforms.Compose([
        transforms.Resize(image_size), # Resize from 28 x 28 to 32 x 32 (so power of 2)
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # Map pixel values into [-1, 1], matching the generator's Tanh output.
        transforms.Normalize((0.5,), (0.5,))
    ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)
# Plot some training images (first 64 of one shuffled batch)
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
Random Seed: 999 Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:00<00:00, 283709977.58it/s]
Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 34501764.12it/s]
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:00<00:00, 129083214.75it/s]
Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 5312473.16it/s]
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw
<matplotlib.image.AxesImage at 0x7f3c61f0c820>
# Generator Code
class Generator(nn.Module):
    """DCGAN generator: maps a latent vector (nz x 1 x 1) to an nc x 32 x 32 image."""

    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        stages = [
            # nz x 1 x 1 -> (ngf*4) x 4 x 4
            nn.ConvTranspose2d(nz, ngf * 4, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),  # inplace ReLU
            # (ngf*4) x 4 x 4 -> (ngf*2) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # (ngf*2) x 8 x 8 -> ngf x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # ngf x 16 x 16 -> nc x 32 x 32; Tanh outputs values in [-1, 1]
            # to match the normalized pixel range of the training data.
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        return self.main(input)
class Discriminator(nn.Module):
    """DCGAN discriminator: maps an nc x 32 x 32 image to a real/fake probability."""

    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        stages = [
            # (nc) x 32 x 32 -> ndf x 16 x 16 (no BatchNorm on the input layer)
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # ndf x 16 x 16 -> (ndf*2) x 8 x 8
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*2) x 8 x 8 -> (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*4) x 4 x 4 -> 1 x 1 x 1, squashed to a probability
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        return self.main(input)
# Initialize networks (DCGAN baseline: BatchNorm in both G and D)
netG = initialize_net(Generator, weights_init, device, ngpu)
netD = initialize_net(Discriminator, weights_init, device, ngpu)
# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
# Training Loop
# Lists to keep track of progress
img_list = []   # snapshots of G(fixed_noise) for visualization
G_losses = []   # per-iteration generator loss
D_losses = []   # per-iteration discriminator loss
iters = 0
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()
        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        # (detach so this backward pass does not flow gradients into G)
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        # (accumulates on top of the real-batch gradients from above)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()
        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()
        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        iters += 1
Generator(
(main): Sequential(
(0): ConvTranspose2d(100, 32, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): Tanh()
)
)
Discriminator(
(main): Sequential(
(0): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(32, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(9): Sigmoid()
)
)
Starting Training Loop...
[0/5][0/469] Loss_D: 1.4303 Loss_G: 0.7566 D(x): 0.4824 D(G(z)): 0.4978 / 0.4725
[0/5][50/469] Loss_D: 0.5966 Loss_G: 1.3870 D(x): 0.7972 D(G(z)): 0.3009 / 0.2563
[0/5][100/469] Loss_D: 0.2972 Loss_G: 2.1574 D(x): 0.8821 D(G(z)): 0.1524 / 0.1224
[0/5][150/469] Loss_D: 0.1240 Loss_G: 2.8461 D(x): 0.9539 D(G(z)): 0.0729 / 0.0660
[0/5][200/469] Loss_D: 0.0684 Loss_G: 3.4240 D(x): 0.9768 D(G(z)): 0.0436 / 0.0386
[0/5][250/469] Loss_D: 0.0440 Loss_G: 3.7561 D(x): 0.9884 D(G(z)): 0.0317 / 0.0276
[0/5][300/469] Loss_D: 0.0334 Loss_G: 4.1022 D(x): 0.9938 D(G(z)): 0.0268 / 0.0196
[0/5][350/469] Loss_D: 0.0226 Loss_G: 4.7683 D(x): 0.9904 D(G(z)): 0.0127 / 0.0102
[0/5][400/469] Loss_D: 0.0177 Loss_G: 4.8415 D(x): 0.9927 D(G(z)): 0.0103 / 0.0084
[0/5][450/469] Loss_D: 0.0164 Loss_G: 5.2067 D(x): 0.9918 D(G(z)): 0.0081 / 0.0062
[1/5][0/469] Loss_D: 0.0157 Loss_G: 5.2740 D(x): 0.9916 D(G(z)): 0.0073 / 0.0066
[1/5][50/469] Loss_D: 0.0397 Loss_G: 4.1889 D(x): 0.9783 D(G(z)): 0.0175 / 0.0161
[1/5][100/469] Loss_D: 0.0328 Loss_G: 4.4058 D(x): 0.9894 D(G(z)): 0.0216 / 0.0145
[1/5][150/469] Loss_D: 0.0364 Loss_G: 4.4479 D(x): 0.9835 D(G(z)): 0.0193 / 0.0140
[1/5][200/469] Loss_D: 0.0261 Loss_G: 4.8180 D(x): 0.9894 D(G(z)): 0.0152 / 0.0095
[1/5][250/469] Loss_D: 0.0729 Loss_G: 3.6828 D(x): 0.9707 D(G(z)): 0.0417 / 0.0279
[1/5][300/469] Loss_D: 0.0478 Loss_G: 3.9377 D(x): 0.9706 D(G(z)): 0.0175 / 0.0218
[1/5][350/469] Loss_D: 0.2376 Loss_G: 2.3146 D(x): 0.8734 D(G(z)): 0.0937 / 0.1040
[1/5][400/469] Loss_D: 0.1190 Loss_G: 3.6090 D(x): 0.9642 D(G(z)): 0.0761 / 0.0347
[1/5][450/469] Loss_D: 0.2341 Loss_G: 2.1788 D(x): 0.8706 D(G(z)): 0.0828 / 0.1257
[2/5][0/469] Loss_D: 0.1488 Loss_G: 2.7907 D(x): 0.9103 D(G(z)): 0.0504 / 0.0754
[2/5][50/469] Loss_D: 0.2078 Loss_G: 4.2654 D(x): 0.9740 D(G(z)): 0.1591 / 0.0169
[2/5][100/469] Loss_D: 0.4250 Loss_G: 2.2731 D(x): 0.9103 D(G(z)): 0.2623 / 0.1252
[2/5][150/469] Loss_D: 0.2491 Loss_G: 2.8262 D(x): 0.9253 D(G(z)): 0.1522 / 0.0681
[2/5][200/469] Loss_D: 0.2782 Loss_G: 2.8669 D(x): 0.9052 D(G(z)): 0.1548 / 0.0688
[2/5][250/469] Loss_D: 0.2639 Loss_G: 2.6223 D(x): 0.8578 D(G(z)): 0.0980 / 0.0814
[2/5][300/469] Loss_D: 0.5437 Loss_G: 1.9817 D(x): 0.8708 D(G(z)): 0.3163 / 0.1578
[2/5][350/469] Loss_D: 0.3995 Loss_G: 2.1066 D(x): 0.8904 D(G(z)): 0.2361 / 0.1383
[2/5][400/469] Loss_D: 0.3721 Loss_G: 2.7850 D(x): 0.8977 D(G(z)): 0.2217 / 0.0716
[2/5][450/469] Loss_D: 0.4028 Loss_G: 3.0969 D(x): 0.9254 D(G(z)): 0.2669 / 0.0525
[3/5][0/469] Loss_D: 0.3565 Loss_G: 2.0037 D(x): 0.8438 D(G(z)): 0.1614 / 0.1494
[3/5][50/469] Loss_D: 0.8738 Loss_G: 2.7901 D(x): 0.9695 D(G(z)): 0.5489 / 0.0702
[3/5][100/469] Loss_D: 0.3342 Loss_G: 2.3238 D(x): 0.8686 D(G(z)): 0.1689 / 0.1089
[3/5][150/469] Loss_D: 1.0142 Loss_G: 4.1669 D(x): 0.9655 D(G(z)): 0.5988 / 0.0181
[3/5][200/469] Loss_D: 0.4335 Loss_G: 2.5712 D(x): 0.8926 D(G(z)): 0.2640 / 0.0858
[3/5][250/469] Loss_D: 0.5521 Loss_G: 1.4452 D(x): 0.6382 D(G(z)): 0.0675 / 0.2619
[3/5][300/469] Loss_D: 0.6437 Loss_G: 0.8712 D(x): 0.5871 D(G(z)): 0.0602 / 0.4422
[3/5][350/469] Loss_D: 0.8271 Loss_G: 2.8187 D(x): 0.9563 D(G(z)): 0.5172 / 0.0699
[3/5][400/469] Loss_D: 0.5000 Loss_G: 1.9154 D(x): 0.8155 D(G(z)): 0.2389 / 0.1680
[3/5][450/469] Loss_D: 0.4576 Loss_G: 1.6751 D(x): 0.8357 D(G(z)): 0.2277 / 0.2098
[4/5][0/469] Loss_D: 0.4755 Loss_G: 2.7527 D(x): 0.8764 D(G(z)): 0.2790 / 0.0726
[4/5][50/469] Loss_D: 0.4261 Loss_G: 1.7925 D(x): 0.8235 D(G(z)): 0.1913 / 0.1911
[4/5][100/469] Loss_D: 0.8147 Loss_G: 1.2830 D(x): 0.5204 D(G(z)): 0.0914 / 0.3016
[4/5][150/469] Loss_D: 1.8502 Loss_G: 0.6741 D(x): 0.1963 D(G(z)): 0.0276 / 0.5426
[4/5][200/469] Loss_D: 0.7464 Loss_G: 3.2509 D(x): 0.9556 D(G(z)): 0.4870 / 0.0442
[4/5][250/469] Loss_D: 0.4471 Loss_G: 1.6267 D(x): 0.8615 D(G(z)): 0.2426 / 0.2208
[4/5][300/469] Loss_D: 0.4634 Loss_G: 1.8377 D(x): 0.7386 D(G(z)): 0.1202 / 0.1918
[4/5][350/469] Loss_D: 0.5774 Loss_G: 1.2724 D(x): 0.6733 D(G(z)): 0.1382 / 0.3064
[4/5][400/469] Loss_D: 0.5334 Loss_G: 1.0508 D(x): 0.6553 D(G(z)): 0.0703 / 0.3765
[4/5][450/469] Loss_D: 0.8457 Loss_G: 3.2769 D(x): 0.9519 D(G(z)): 0.5283 / 0.0447
# plot the loss for generator and discriminator
plot_GAN_loss([G_losses, D_losses], ["G", "D"])
# Grab a batch of real images from the dataloader and compare with generated samples
plot_real_fake_images(next(iter(dataloader)), img_list)
# Generator Code
class Generator_woBN(nn.Module):
    """Ablation variant of the DCGAN generator with every BatchNorm layer removed."""

    def __init__(self, ngpu):
        super(Generator_woBN, self).__init__()
        self.ngpu = ngpu
        stages = [
            # nz x 1 x 1 -> (ngf*4) x 4 x 4 (BatchNorm removed for the ablation)
            nn.ConvTranspose2d(nz, ngf * 4, kernel_size=4, stride=1, padding=0, bias=False),
            nn.ReLU(True),  # inplace ReLU
            # (ngf*4) x 4 x 4 -> (ngf*2) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.ReLU(True),
            # (ngf*2) x 8 x 8 -> ngf x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.ReLU(True),
            # ngf x 16 x 16 -> nc x 32 x 32; Tanh outputs [-1, 1] pixels
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        return self.main(input)
class Discriminator_woBN(nn.Module):
    """Ablation variant of the DCGAN discriminator with every BatchNorm layer removed."""

    def __init__(self, ngpu):
        super(Discriminator_woBN, self).__init__()
        self.ngpu = ngpu
        stages = [
            # (nc) x 32 x 32 -> ndf x 16 x 16
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # ndf x 16 x 16 -> (ndf*2) x 8 x 8 (BatchNorm removed for the ablation)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*2) x 8 x 8 -> (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*4) x 4 x 4 -> 1 x 1 x 1, squashed to a probability
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        return self.main(input)
# Instantiate the no-BatchNorm ablation networks with the same DCGAN initializer.
netG_noBN = initialize_net(Generator_woBN, weights_init, device, ngpu)
netD_noBN = initialize_net(Discriminator_woBN, weights_init, device, ngpu)
Generator_woBN(
(main): Sequential(
(0): ConvTranspose2d(100, 32, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): ReLU(inplace=True)
(2): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): ReLU(inplace=True)
(4): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): Tanh()
)
)
Discriminator_woBN(
(main): Sequential(
(0): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): LeakyReLU(negative_slope=0.2, inplace=True)
(4): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(5): LeakyReLU(negative_slope=0.2, inplace=True)
(6): Conv2d(32, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(7): Sigmoid()
)
)
# Setup Adam optimizers for both G and D (no-BatchNorm ablation run)
optimizerD_noBN = optim.Adam(netD_noBN.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG_noBN = optim.Adam(netG_noBN.parameters(), lr=lr, betas=(beta1, 0.999))
# Training Loop — identical to the baseline loop except that the
# no-BatchNorm networks and their optimizers are used throughout.
# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD_noBN.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = netD_noBN(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()
        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG_noBN(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        # (detach so this backward pass does not flow gradients into G)
        output = netD_noBN(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD_noBN.step()
        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG_noBN.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD_noBN(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG_noBN.step()
        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG_noBN(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        iters += 1
Starting Training Loop... [0/5][0/469] Loss_D: 1.3862 Loss_G: 0.6931 D(x): 0.5001 D(G(z)): 0.5000 / 0.5000 [0/5][50/469] Loss_D: 1.1952 Loss_G: 0.4654 D(x): 0.8301 D(G(z)): 0.6348 / 0.6280 [0/5][100/469] Loss_D: 0.8022 Loss_G: 0.9934 D(x): 0.7311 D(G(z)): 0.3842 / 0.3708 [0/5][150/469] Loss_D: 0.7584 Loss_G: 0.6824 D(x): 0.9519 D(G(z)): 0.5046 / 0.5058 [0/5][200/469] Loss_D: 1.0214 Loss_G: 0.5153 D(x): 0.9189 D(G(z)): 0.6033 / 0.5978 [0/5][250/469] Loss_D: 1.2341 Loss_G: 0.4859 D(x): 0.7746 D(G(z)): 0.6177 / 0.6153 [0/5][300/469] Loss_D: 0.9598 Loss_G: 0.6933 D(x): 0.7836 D(G(z)): 0.4970 / 0.5001 [0/5][350/469] Loss_D: 0.6787 Loss_G: 0.8238 D(x): 0.9145 D(G(z)): 0.4390 / 0.4392 [0/5][400/469] Loss_D: 0.6651 Loss_G: 0.7524 D(x): 0.9737 D(G(z)): 0.4711 / 0.4715 [0/5][450/469] Loss_D: 0.6900 Loss_G: 0.7105 D(x): 0.9877 D(G(z)): 0.4919 / 0.4915 [1/5][0/469] Loss_D: 0.7182 Loss_G: 0.6823 D(x): 0.9883 D(G(z)): 0.5064 / 0.5055 [1/5][50/469] Loss_D: 0.7246 Loss_G: 0.6931 D(x): 0.9721 D(G(z)): 0.4994 / 0.5001 [1/5][100/469] Loss_D: 0.6928 Loss_G: 0.7128 D(x): 0.9869 D(G(z)): 0.4921 / 0.4903 [1/5][150/469] Loss_D: 0.6932 Loss_G: 0.6997 D(x): 0.9958 D(G(z)): 0.4978 / 0.4967 [1/5][200/469] Loss_D: 0.7105 Loss_G: 0.6813 D(x): 0.9951 D(G(z)): 0.5059 / 0.5059 [1/5][250/469] Loss_D: 0.7194 Loss_G: 0.6701 D(x): 0.9979 D(G(z)): 0.5119 / 0.5117 [1/5][300/469] Loss_D: 0.7978 Loss_G: 0.6027 D(x): 0.9970 D(G(z)): 0.5482 / 0.5474 [1/5][350/469] Loss_D: 0.8425 Loss_G: 0.5770 D(x): 0.9897 D(G(z)): 0.5647 / 0.5616 [1/5][400/469] Loss_D: 0.7846 Loss_G: 0.6458 D(x): 0.9680 D(G(z)): 0.5283 / 0.5243 [1/5][450/469] Loss_D: 0.7694 Loss_G: 0.6759 D(x): 0.9510 D(G(z)): 0.5103 / 0.5087 [2/5][0/469] Loss_D: 0.7779 Loss_G: 0.6590 D(x): 0.9562 D(G(z)): 0.5181 / 0.5174 [2/5][50/469] Loss_D: 0.8922 Loss_G: 0.5653 D(x): 0.9526 D(G(z)): 0.5692 / 0.5683 [2/5][100/469] Loss_D: 0.7689 Loss_G: 0.6693 D(x): 0.9547 D(G(z)): 0.5140 / 0.5121 [2/5][150/469] Loss_D: 0.7870 Loss_G: 0.6666 D(x): 0.9484 D(G(z)): 0.5192 
/ 0.5138 [2/5][200/469] Loss_D: 0.8293 Loss_G: 0.6275 D(x): 0.9539 D(G(z)): 0.5415 / 0.5342 [2/5][250/469] Loss_D: 0.8026 Loss_G: 0.6535 D(x): 0.9456 D(G(z)): 0.5239 / 0.5207 [2/5][300/469] Loss_D: 0.7496 Loss_G: 0.6929 D(x): 0.9456 D(G(z)): 0.4983 / 0.5002 [2/5][350/469] Loss_D: 0.8283 Loss_G: 0.6217 D(x): 0.9550 D(G(z)): 0.5404 / 0.5373 [2/5][400/469] Loss_D: 0.9913 Loss_G: 0.5764 D(x): 0.9004 D(G(z)): 0.5752 / 0.5673 [2/5][450/469] Loss_D: 1.0604 Loss_G: 0.5013 D(x): 0.9082 D(G(z)): 0.6131 / 0.6078 [3/5][0/469] Loss_D: 1.0854 Loss_G: 0.4794 D(x): 0.9036 D(G(z)): 0.6229 / 0.6202 [3/5][50/469] Loss_D: 1.0445 Loss_G: 0.5252 D(x): 0.8725 D(G(z)): 0.5927 / 0.5920 [3/5][100/469] Loss_D: 0.9873 Loss_G: 0.5963 D(x): 0.8615 D(G(z)): 0.5615 / 0.5520 [3/5][150/469] Loss_D: 0.9200 Loss_G: 0.6243 D(x): 0.8908 D(G(z)): 0.5462 / 0.5363 [3/5][200/469] Loss_D: 0.9313 Loss_G: 0.6356 D(x): 0.8676 D(G(z)): 0.5415 / 0.5302 [3/5][250/469] Loss_D: 0.8727 Loss_G: 0.6393 D(x): 0.9122 D(G(z)): 0.5402 / 0.5280 [3/5][300/469] Loss_D: 1.0264 Loss_G: 0.5378 D(x): 0.7610 D(G(z)): 0.5197 / 0.5859 [3/5][350/469] Loss_D: 0.8204 Loss_G: 0.7601 D(x): 0.8250 D(G(z)): 0.4592 / 0.4687 [3/5][400/469] Loss_D: 0.5009 Loss_G: 0.9467 D(x): 0.9993 D(G(z)): 0.3934 / 0.3883 [3/5][450/469] Loss_D: 0.7770 Loss_G: 0.7528 D(x): 0.9058 D(G(z)): 0.4851 / 0.4715 [4/5][0/469] Loss_D: 0.9830 Loss_G: 0.6197 D(x): 0.8272 D(G(z)): 0.5405 / 0.5386 [4/5][50/469] Loss_D: 0.9048 Loss_G: 0.6230 D(x): 0.8457 D(G(z)): 0.5166 / 0.5375 [4/5][100/469] Loss_D: 0.8361 Loss_G: 0.6865 D(x): 0.9080 D(G(z)): 0.5200 / 0.5040 [4/5][150/469] Loss_D: 0.7865 Loss_G: 0.7342 D(x): 0.8882 D(G(z)): 0.4842 / 0.4801 [4/5][200/469] Loss_D: 0.7616 Loss_G: 0.7036 D(x): 0.9398 D(G(z)): 0.5023 / 0.4950 [4/5][250/469] Loss_D: 0.7493 Loss_G: 0.7195 D(x): 0.9160 D(G(z)): 0.4807 / 0.4874 [4/5][300/469] Loss_D: 0.8603 Loss_G: 0.6757 D(x): 0.8640 D(G(z)): 0.4937 / 0.5109 [4/5][350/469] Loss_D: 0.6998 Loss_G: 0.7947 D(x): 0.9380 D(G(z)): 0.4681 / 0.4521 
[4/5][400/469] Loss_D: 0.7850 Loss_G: 0.6802 D(x): 0.9579 D(G(z)): 0.5226 / 0.5073 [4/5][450/469] Loss_D: 0.6922 Loss_G: 0.7281 D(x): 0.9715 D(G(z)): 0.4841 / 0.4831
# plot the loss for generator and discriminator (no-BatchNorm ablation)
plot_GAN_loss([G_losses, D_losses], ["G", "D"])
# Grab a batch of real images from the dataloader and compare with generated samples
plot_real_fake_images(next(iter(dataloader)), img_list)
# Re-initialize the generator and discriminator networks.
netG = initialize_net(Generator, weights_init, device, ngpu)
netD = initialize_net(Discriminator, weights_init, device, ngpu)
# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
# Training Loop
# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        #     Ablation: real and fake images are scored in ONE combined
        #     forward pass through D instead of two separate passes.
        ###########################
        ################################ YOUR CODE ################################
        # Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label_r = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        #output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        #errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        #errD_real.backward()
        #D_x = output.mean().item()
        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label_f = torch.full((b_size,), fake_label, device=device)
        # label.fill_(fake_label)
        # Concatenate real and fake images (and their labels) into one batch.
        in_tot = torch.cat((real_cpu, fake), dim=0)
        label_tot = torch.cat((label_r, label_f), dim=0)
        # Classify the combined real+fake batch with D
        # (detach so gradients do not flow back into G during the D step)
        output = netD(in_tot.detach()).view(-1)
        # Calculate D's loss on the combined batch.
        # NOTE(review): with BatchNorm in D, batch statistics are now computed
        # over the mixed real+fake batch (unlike the baseline's two separate
        # passes), and BCELoss averages over 2*b_size samples, so errD's scale
        # is roughly half of the baseline's errD_real + errD_fake.
        errD = criterion(output, label_tot)
        # Calculate the gradients for this batch
        errD.backward()
        #D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        #errD = errD_real + errD_fake
        # Update D
        optimizerD.step()
        ############################ END YOUR CODE ##############################
        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label = torch.full((b_size,), real_label, device=device) # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()
        # Output training stats (D_x and D_G_z1 are unavailable in this
        # variant since real and fake scores are not computed separately)
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(G(z)): %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_G_z2))
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        iters += 1
Generator(
(main): Sequential(
(0): ConvTranspose2d(100, 32, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): Tanh()
)
)
Discriminator(
(main): Sequential(
(0): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(32, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(9): Sigmoid()
)
)
Starting Training Loop...
[0/5][0/469] Loss_D: 0.7323 Loss_G: 0.7700 D(G(z)): 0.4673
[0/5][50/469] Loss_D: 0.0267 Loss_G: 0.1055 D(G(z)): 0.9001
[0/5][100/469] Loss_D: 0.0101 Loss_G: 0.0410 D(G(z)): 0.9599
[0/5][150/469] Loss_D: 0.0059 Loss_G: 0.0247 D(G(z)): 0.9756
[0/5][200/469] Loss_D: 0.0038 Loss_G: 0.0168 D(G(z)): 0.9833
[0/5][250/469] Loss_D: 0.0027 Loss_G: 0.0124 D(G(z)): 0.9877
[0/5][300/469] Loss_D: 0.0020 Loss_G: 0.0095 D(G(z)): 0.9906
[0/5][350/469] Loss_D: 0.0016 Loss_G: 0.0074 D(G(z)): 0.9926
[0/5][400/469] Loss_D: 0.0013 Loss_G: 0.0060 D(G(z)): 0.9940
[0/5][450/469] Loss_D: 0.0011 Loss_G: 0.0048 D(G(z)): 0.9952
[1/5][0/469] Loss_D: 0.0010 Loss_G: 0.0046 D(G(z)): 0.9954
[1/5][50/469] Loss_D: 0.0008 Loss_G: 0.0039 D(G(z)): 0.9961
[1/5][100/469] Loss_D: 0.0007 Loss_G: 0.0034 D(G(z)): 0.9966
[1/5][150/469] Loss_D: 0.0006 Loss_G: 0.0030 D(G(z)): 0.9970
[1/5][200/469] Loss_D: 0.0005 Loss_G: 0.0027 D(G(z)): 0.9973
[1/5][250/469] Loss_D: 0.0005 Loss_G: 0.0023 D(G(z)): 0.9977
[1/5][300/469] Loss_D: 0.0004 Loss_G: 0.0022 D(G(z)): 0.9978
[1/5][350/469] Loss_D: 0.0004 Loss_G: 0.0019 D(G(z)): 0.9981
[1/5][400/469] Loss_D: 0.0003 Loss_G: 0.0018 D(G(z)): 0.9982
[1/5][450/469] Loss_D: 0.0003 Loss_G: 0.0016 D(G(z)): 0.9984
[2/5][0/469] Loss_D: 0.0003 Loss_G: 0.0016 D(G(z)): 0.9984
[2/5][50/469] Loss_D: 0.0003 Loss_G: 0.0015 D(G(z)): 0.9985
[2/5][100/469] Loss_D: 0.0002 Loss_G: 0.0014 D(G(z)): 0.9986
[2/5][150/469] Loss_D: 0.0002 Loss_G: 0.0012 D(G(z)): 0.9988
[2/5][200/469] Loss_D: 0.0002 Loss_G: 0.0012 D(G(z)): 0.9988
[2/5][250/469] Loss_D: 0.0002 Loss_G: 0.0011 D(G(z)): 0.9989
[2/5][300/469] Loss_D: 0.0002 Loss_G: 0.0010 D(G(z)): 0.9990
[2/5][350/469] Loss_D: 0.0002 Loss_G: 0.0010 D(G(z)): 0.9990
[2/5][400/469] Loss_D: 0.0002 Loss_G: 0.0009 D(G(z)): 0.9991
[2/5][450/469] Loss_D: 0.0001 Loss_G: 0.0009 D(G(z)): 0.9991
[3/5][0/469] Loss_D: 0.0001 Loss_G: 0.0008 D(G(z)): 0.9992
[3/5][50/469] Loss_D: 0.0001 Loss_G: 0.0008 D(G(z)): 0.9992
[3/5][100/469] Loss_D: 0.0001 Loss_G: 0.0008 D(G(z)): 0.9992
[3/5][150/469] Loss_D: 0.0001 Loss_G: 0.0007 D(G(z)): 0.9993
[3/5][200/469] Loss_D: 0.0001 Loss_G: 0.0007 D(G(z)): 0.9993
[3/5][250/469] Loss_D: 0.0001 Loss_G: 0.0006 D(G(z)): 0.9994
[3/5][300/469] Loss_D: 0.0001 Loss_G: 0.0006 D(G(z)): 0.9994
[3/5][350/469] Loss_D: 0.0001 Loss_G: 0.0006 D(G(z)): 0.9994
[3/5][400/469] Loss_D: 0.0001 Loss_G: 0.0005 D(G(z)): 0.9995
[3/5][450/469] Loss_D: 0.0001 Loss_G: 0.0005 D(G(z)): 0.9995
[4/5][0/469] Loss_D: 0.0001 Loss_G: 0.0005 D(G(z)): 0.9995
[4/5][50/469] Loss_D: 0.0001 Loss_G: 0.0005 D(G(z)): 0.9995
[4/5][100/469] Loss_D: 0.0001 Loss_G: 0.0005 D(G(z)): 0.9995
[4/5][150/469] Loss_D: 0.0001 Loss_G: 0.0005 D(G(z)): 0.9995
[4/5][200/469] Loss_D: 0.0001 Loss_G: 0.0004 D(G(z)): 0.9996
[4/5][250/469] Loss_D: 0.0001 Loss_G: 0.0004 D(G(z)): 0.9996
[4/5][300/469] Loss_D: 0.0001 Loss_G: 0.0004 D(G(z)): 0.9996
[4/5][350/469] Loss_D: 0.0001 Loss_G: 0.0004 D(G(z)): 0.9996
[4/5][400/469] Loss_D: 0.0001 Loss_G: 0.0004 D(G(z)): 0.9996
[4/5][450/469] Loss_D: 0.0001 Loss_G: 0.0003 D(G(z)): 0.9997
# Visualize the concatenated-batch run: G/D loss curves, then a grid of
# real images beside the saved generator snapshots.
plot_GAN_loss([G_losses, D_losses], ["G", "D"])
real_batch = next(iter(dataloader))
plot_real_fake_images(real_batch, img_list)
# Fresh generator and discriminator weights for the next ablation run.
netG = initialize_net(Generator, weights_init, device, ngpu)
netD = initialize_net(Discriminator, weights_init, device, ngpu)

# Adam optimizers for both networks, same DCGAN settings as the baseline.
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
# Training Loop
# Ablation variant: the generator minimizes the ORIGINAL saturating minimax
# loss log(1 - D(G(z))) instead of the non-saturating -log(D(G(z))) trick.

# Progress bookkeeping.
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # -- all-real batch --
        netD.zero_grad()
        real_images = data[0].to(device)
        b_size = real_images.size(0)
        label = torch.full((b_size,), real_label, device=device)
        output = netD(real_images).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # -- all-fake batch --
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        # detach() so this backward pass does not touch G's parameters.
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Total D loss is the sum of the real and fake terms.
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network (saturating variant)
        ###########################
        netG.zero_grad()
        # BCE against the FAKE target is -log(1 - D(G(z))); negating it gives
        # log(1 - D(G(z))), the original minimax generator objective, so the
        # loss reported for G here is negative by construction.
        label = torch.full((b_size,), fake_label, device=device)
        # D was just updated, so run the fake batch through it again.
        output = netD(fake).view(-1)
        errG = -criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save losses for plotting later.
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Periodically snapshot G's output on fixed_noise to track progress.
        if (iters % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        iters += 1
Generator(
(main): Sequential(
(0): ConvTranspose2d(100, 32, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): Tanh()
)
)
Discriminator(
(main): Sequential(
(0): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(32, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(9): Sigmoid()
)
)
Starting Training Loop...
[0/5][0/469] Loss_D: 1.4907 Loss_G: -0.7539 D(x): 0.5020 D(G(z)): 0.5416 / 0.5227
[0/5][50/469] Loss_D: 0.5494 Loss_G: -0.2846 D(x): 0.8201 D(G(z)): 0.2899 / 0.2449
[0/5][100/469] Loss_D: 0.1402 Loss_G: -0.0742 D(x): 0.9497 D(G(z)): 0.0835 / 0.0712
[0/5][150/469] Loss_D: 0.0845 Loss_G: -0.0406 D(x): 0.9657 D(G(z)): 0.0473 / 0.0396
[0/5][200/469] Loss_D: 0.0198 Loss_G: -0.0107 D(x): 0.9914 D(G(z)): 0.0111 / 0.0106
[0/5][250/469] Loss_D: 0.0191 Loss_G: -0.0108 D(x): 0.9927 D(G(z)): 0.0116 / 0.0107
[0/5][300/469] Loss_D: 0.0177 Loss_G: -0.0107 D(x): 0.9944 D(G(z)): 0.0119 / 0.0107
[0/5][350/469] Loss_D: 0.0141 Loss_G: -0.0079 D(x): 0.9939 D(G(z)): 0.0080 / 0.0079
[0/5][400/469] Loss_D: 0.0095 Loss_G: -0.0047 D(x): 0.9953 D(G(z)): 0.0048 / 0.0047
[0/5][450/469] Loss_D: 0.0054 Loss_G: -0.0028 D(x): 0.9976 D(G(z)): 0.0030 / 0.0028
[1/5][0/469] Loss_D: 0.0052 Loss_G: -0.0026 D(x): 0.9973 D(G(z)): 0.0025 / 0.0026
[1/5][50/469] Loss_D: 0.0051 Loss_G: -0.0023 D(x): 0.9971 D(G(z)): 0.0022 / 0.0023
[1/5][100/469] Loss_D: 0.0037 Loss_G: -0.0019 D(x): 0.9981 D(G(z)): 0.0018 / 0.0019
[1/5][150/469] Loss_D: 0.0033 Loss_G: -0.0017 D(x): 0.9984 D(G(z)): 0.0018 / 0.0017
[1/5][200/469] Loss_D: 0.0029 Loss_G: -0.0015 D(x): 0.9986 D(G(z)): 0.0015 / 0.0015
[1/5][250/469] Loss_D: 0.0023 Loss_G: -0.0013 D(x): 0.9990 D(G(z)): 0.0013 / 0.0013
[1/5][300/469] Loss_D: 0.0029 Loss_G: -0.0014 D(x): 0.9983 D(G(z)): 0.0012 / 0.0014
[1/5][350/469] Loss_D: 0.0018 Loss_G: -0.0010 D(x): 0.9991 D(G(z)): 0.0010 / 0.0010
[1/5][400/469] Loss_D: 0.0017 Loss_G: -0.0009 D(x): 0.9992 D(G(z)): 0.0009 / 0.0009
[1/5][450/469] Loss_D: 0.0015 Loss_G: -0.0008 D(x): 0.9993 D(G(z)): 0.0008 / 0.0008
[2/5][0/469] Loss_D: 0.0013 Loss_G: -0.0007 D(x): 0.9994 D(G(z)): 0.0007 / 0.0007
[2/5][50/469] Loss_D: 0.0014 Loss_G: -0.0006 D(x): 0.9992 D(G(z)): 0.0006 / 0.0006
[2/5][100/469] Loss_D: 0.0010 Loss_G: -0.0005 D(x): 0.9995 D(G(z)): 0.0005 / 0.0005
[2/5][150/469] Loss_D: 0.0009 Loss_G: -0.0005 D(x): 0.9996 D(G(z)): 0.0005 / 0.0005
[2/5][200/469] Loss_D: 0.0010 Loss_G: -0.0005 D(x): 0.9994 D(G(z)): 0.0005 / 0.0005
[2/5][250/469] Loss_D: 0.0007 Loss_G: -0.0004 D(x): 0.9997 D(G(z)): 0.0005 / 0.0004
[2/5][300/469] Loss_D: 0.0008 Loss_G: -0.0004 D(x): 0.9996 D(G(z)): 0.0004 / 0.0004
[2/5][350/469] Loss_D: 0.0006 Loss_G: -0.0003 D(x): 0.9997 D(G(z)): 0.0003 / 0.0003
[2/5][400/469] Loss_D: 0.0007 Loss_G: -0.0003 D(x): 0.9996 D(G(z)): 0.0003 / 0.0003
[2/5][450/469] Loss_D: 0.0006 Loss_G: -0.0003 D(x): 0.9997 D(G(z)): 0.0003 / 0.0003
[3/5][0/469] Loss_D: 0.0005 Loss_G: -0.0003 D(x): 0.9998 D(G(z)): 0.0003 / 0.0003
[3/5][50/469] Loss_D: 0.0005 Loss_G: -0.0003 D(x): 0.9998 D(G(z)): 0.0003 / 0.0003
[3/5][100/469] Loss_D: 0.0005 Loss_G: -0.0003 D(x): 0.9998 D(G(z)): 0.0003 / 0.0003
[3/5][150/469] Loss_D: 0.0005 Loss_G: -0.0002 D(x): 0.9997 D(G(z)): 0.0002 / 0.0002
[3/5][200/469] Loss_D: 0.0004 Loss_G: -0.0002 D(x): 0.9998 D(G(z)): 0.0002 / 0.0002
[3/5][250/469] Loss_D: 0.0004 Loss_G: -0.0002 D(x): 0.9998 D(G(z)): 0.0002 / 0.0002
[3/5][300/469] Loss_D: 0.0003 Loss_G: -0.0002 D(x): 0.9999 D(G(z)): 0.0002 / 0.0002
[3/5][350/469] Loss_D: 0.0003 Loss_G: -0.0002 D(x): 0.9998 D(G(z)): 0.0002 / 0.0002
[3/5][400/469] Loss_D: 0.0003 Loss_G: -0.0002 D(x): 0.9999 D(G(z)): 0.0002 / 0.0002
[3/5][450/469] Loss_D: 0.0003 Loss_G: -0.0002 D(x): 0.9999 D(G(z)): 0.0002 / 0.0002
[4/5][0/469] Loss_D: 0.0003 Loss_G: -0.0002 D(x): 0.9998 D(G(z)): 0.0002 / 0.0002
[4/5][50/469] Loss_D: 0.0003 Loss_G: -0.0002 D(x): 0.9999 D(G(z)): 0.0002 / 0.0002
[4/5][100/469] Loss_D: 0.0003 Loss_G: -0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
[4/5][150/469] Loss_D: 0.0002 Loss_G: -0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
[4/5][200/469] Loss_D: 0.0002 Loss_G: -0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
[4/5][250/469] Loss_D: 0.0002 Loss_G: -0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
[4/5][300/469] Loss_D: 0.0002 Loss_G: -0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
[4/5][350/469] Loss_D: 0.0002 Loss_G: -0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
[4/5][400/469] Loss_D: 0.0002 Loss_G: -0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
[4/5][450/469] Loss_D: 0.0002 Loss_G: -0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
# Visualize the saturating-loss run: G/D loss curves, then real images
# next to the generator snapshots collected during training.
plot_GAN_loss([G_losses, D_losses], ["G", "D"])
real_batch = next(iter(dataloader))
plot_real_fake_images(real_batch, img_list)
Task: Use the `initialize_net` function provided in Task 1.0 to initialize the generator and discriminator networks WITHOUT the custom weight initialization (HINT: there is no need to modify the code of the `initialize_net` function itself).
################################ YOUR CODE ################################
# Passing None as the weight-init function skips the custom DCGAN
# initialization, leaving PyTorch's default layer initialization in place.
netG_woinit = initialize_net(Generator, None, device, ngpu)
netD_woinit = initialize_net(Discriminator, None, device, ngpu)
########################### END YOUR CODE ###############################
Generator(
(main): Sequential(
(0): ConvTranspose2d(100, 32, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): Tanh()
)
)
Discriminator(
(main): Sequential(
(0): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(32, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(9): Sigmoid()
)
)
# Adam optimizers for the networks built WITHOUT custom weight init,
# using the same hyperparameters as the baseline run.
optimizerD_woinit = optim.Adam(netD_woinit.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG_woinit = optim.Adam(netG_woinit.parameters(), lr=lr, betas=(beta1, 0.999))

# Training Loop
# Identical to the baseline DCGAN loop, but driving the *_woinit networks
# so the effect of skipping weight initialization can be measured.

# Progress bookkeeping.
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # -- all-real batch --
        netD_woinit.zero_grad()
        real_images = data[0].to(device)
        b_size = real_images.size(0)
        label = torch.full((b_size,), real_label, device=device)
        output = netD_woinit(real_images).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # -- all-fake batch --
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG_woinit(noise)
        label.fill_(fake_label)
        # detach() keeps this backward pass away from G's parameters.
        output = netD_woinit(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Total D loss is the sum of the real and fake terms.
        errD = errD_real + errD_fake
        optimizerD_woinit.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG_woinit.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # D was just updated, so run the fake batch through it again.
        output = netD_woinit(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG_woinit.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save losses for plotting later.
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Periodically snapshot G's output on fixed_noise to track progress.
        if (iters % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
            with torch.no_grad():
                fake = netG_woinit(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        iters += 1
Starting Training Loop... [0/5][0/469] Loss_D: 1.4309 Loss_G: 0.6981 D(x): 0.5045 D(G(z)): 0.5139 / 0.5033 [0/5][50/469] Loss_D: 0.6658 Loss_G: 1.1141 D(x): 0.8098 D(G(z)): 0.3598 / 0.3353 [0/5][100/469] Loss_D: 0.5116 Loss_G: 1.5265 D(x): 0.8193 D(G(z)): 0.2573 / 0.2270 [0/5][150/469] Loss_D: 0.3615 Loss_G: 1.9621 D(x): 0.8677 D(G(z)): 0.1888 / 0.1485 [0/5][200/469] Loss_D: 0.2263 Loss_G: 2.6019 D(x): 0.9095 D(G(z)): 0.1175 / 0.0861 [0/5][250/469] Loss_D: 0.1635 Loss_G: 2.9754 D(x): 0.9243 D(G(z)): 0.0766 / 0.0570 [0/5][300/469] Loss_D: 0.1121 Loss_G: 3.6956 D(x): 0.9495 D(G(z)): 0.0540 / 0.0297 [0/5][350/469] Loss_D: 0.0574 Loss_G: 3.9042 D(x): 0.9789 D(G(z)): 0.0346 / 0.0219 [0/5][400/469] Loss_D: 0.0531 Loss_G: 4.3394 D(x): 0.9827 D(G(z)): 0.0340 / 0.0163 [0/5][450/469] Loss_D: 0.0471 Loss_G: 4.1084 D(x): 0.9826 D(G(z)): 0.0285 / 0.0184 [1/5][0/469] Loss_D: 0.0308 Loss_G: 4.5069 D(x): 0.9884 D(G(z)): 0.0189 / 0.0123 [1/5][50/469] Loss_D: 0.0430 Loss_G: 4.2864 D(x): 0.9825 D(G(z)): 0.0247 / 0.0160 [1/5][100/469] Loss_D: 0.0927 Loss_G: 3.5971 D(x): 0.9490 D(G(z)): 0.0312 / 0.0341 [1/5][150/469] Loss_D: 0.0441 Loss_G: 4.4140 D(x): 0.9806 D(G(z)): 0.0235 / 0.0158 [1/5][200/469] Loss_D: 0.0621 Loss_G: 4.4977 D(x): 0.9536 D(G(z)): 0.0108 / 0.0142 [1/5][250/469] Loss_D: 0.0329 Loss_G: 4.7403 D(x): 0.9823 D(G(z)): 0.0145 / 0.0116 [1/5][300/469] Loss_D: 0.0405 Loss_G: 4.5048 D(x): 0.9796 D(G(z)): 0.0178 / 0.0137 [1/5][350/469] Loss_D: 0.0449 Loss_G: 4.4689 D(x): 0.9788 D(G(z)): 0.0221 / 0.0141 [1/5][400/469] Loss_D: 0.0592 Loss_G: 4.8047 D(x): 0.9867 D(G(z)): 0.0439 / 0.0100 [1/5][450/469] Loss_D: 0.1017 Loss_G: 3.7506 D(x): 0.9415 D(G(z)): 0.0312 / 0.0313 [2/5][0/469] Loss_D: 0.0515 Loss_G: 4.1536 D(x): 0.9761 D(G(z)): 0.0247 / 0.0207 [2/5][50/469] Loss_D: 0.0617 Loss_G: 4.2446 D(x): 0.9660 D(G(z)): 0.0252 / 0.0209 [2/5][100/469] Loss_D: 0.0938 Loss_G: 4.4978 D(x): 0.9785 D(G(z)): 0.0661 / 0.0160 [2/5][150/469] Loss_D: 0.0752 Loss_G: 3.7000 D(x): 0.9537 D(G(z)): 0.0245 
/ 0.0366 [2/5][200/469] Loss_D: 0.0975 Loss_G: 4.0539 D(x): 0.9457 D(G(z)): 0.0332 / 0.0229 [2/5][250/469] Loss_D: 0.1918 Loss_G: 3.7465 D(x): 0.9524 D(G(z)): 0.1253 / 0.0346 [2/5][300/469] Loss_D: 0.1045 Loss_G: 3.7999 D(x): 0.9368 D(G(z)): 0.0353 / 0.0317 [2/5][350/469] Loss_D: 0.1044 Loss_G: 3.3083 D(x): 0.9378 D(G(z)): 0.0338 / 0.0474 [2/5][400/469] Loss_D: 0.1094 Loss_G: 3.7568 D(x): 0.9693 D(G(z)): 0.0719 / 0.0322 [2/5][450/469] Loss_D: 0.1122 Loss_G: 3.0594 D(x): 0.9427 D(G(z)): 0.0481 / 0.0654 [3/5][0/469] Loss_D: 1.0178 Loss_G: 3.8874 D(x): 0.5066 D(G(z)): 0.0011 / 0.0359 [3/5][50/469] Loss_D: 0.1128 Loss_G: 3.4805 D(x): 0.9260 D(G(z)): 0.0260 / 0.0427 [3/5][100/469] Loss_D: 0.1609 Loss_G: 2.8178 D(x): 0.9029 D(G(z)): 0.0407 / 0.0821 [3/5][150/469] Loss_D: 0.1238 Loss_G: 3.7801 D(x): 0.9276 D(G(z)): 0.0401 / 0.0375 [3/5][200/469] Loss_D: 0.1913 Loss_G: 3.2935 D(x): 0.9037 D(G(z)): 0.0673 / 0.0600 [3/5][250/469] Loss_D: 0.2128 Loss_G: 2.9394 D(x): 0.8526 D(G(z)): 0.0270 / 0.0764 [3/5][300/469] Loss_D: 0.2324 Loss_G: 2.4964 D(x): 0.8595 D(G(z)): 0.0511 / 0.1113 [3/5][350/469] Loss_D: 0.1967 Loss_G: 2.5937 D(x): 0.9032 D(G(z)): 0.0757 / 0.1081 [3/5][400/469] Loss_D: 0.1936 Loss_G: 2.2668 D(x): 0.8638 D(G(z)): 0.0269 / 0.1451 [3/5][450/469] Loss_D: 0.1141 Loss_G: 3.2765 D(x): 0.9337 D(G(z)): 0.0411 / 0.0552 [4/5][0/469] Loss_D: 0.1290 Loss_G: 3.2662 D(x): 0.9092 D(G(z)): 0.0238 / 0.0574 [4/5][50/469] Loss_D: 0.1190 Loss_G: 3.5920 D(x): 0.9246 D(G(z)): 0.0339 / 0.0416 [4/5][100/469] Loss_D: 0.1664 Loss_G: 3.4862 D(x): 0.9595 D(G(z)): 0.1086 / 0.0497 [4/5][150/469] Loss_D: 0.2272 Loss_G: 3.8006 D(x): 0.9359 D(G(z)): 0.1321 / 0.0375 [4/5][200/469] Loss_D: 0.1660 Loss_G: 2.9409 D(x): 0.8988 D(G(z)): 0.0458 / 0.0896 [4/5][250/469] Loss_D: 0.1796 Loss_G: 3.2309 D(x): 0.9115 D(G(z)): 0.0730 / 0.0637 [4/5][300/469] Loss_D: 0.1116 Loss_G: 3.1267 D(x): 0.9462 D(G(z)): 0.0509 / 0.0630 [4/5][350/469] Loss_D: 0.2065 Loss_G: 2.9305 D(x): 0.9388 D(G(z)): 0.1158 / 0.0728 
[4/5][400/469] Loss_D: 0.1649 Loss_G: 3.5358 D(x): 0.9706 D(G(z)): 0.1171 / 0.0455 [4/5][450/469] Loss_D: 0.1826 Loss_G: 2.8879 D(x): 0.8773 D(G(z)): 0.0353 / 0.0892
# Visualize the no-weight-init run: G/D loss curves, then real images
# beside the saved generator snapshots.
plot_GAN_loss([G_losses, D_losses], ["G", "D"])
real_batch = next(iter(dataloader))
plot_real_fake_images(real_batch, img_list)
Wasserstein GAN (WGAN) is an alternative training strategy to traditional GAN. WGAN may provide more stable learning and may avoid problems faced in traditional GAN training like mode collapse and vanishing gradient. We will not go through the whole derivation of this algorithm but if interested, you can find more details in the arXiv paper above and Prof. Inouye's lecture notes on Wasserstein GANs from ECE 570.
The objective function of WGAN is still a min-max but with a different objective function: $$ \min_G \max_D \mathbb{E}_{p_{data}}[D(x)] - \mathbb{E}_{p_z}[D(G(z))] \,, $$ where $D$ must be a 1-Lipschitz function (rather than a classifier as in regular GANs) and $p_z$ is a standard normal distribution. Notice the similarities and differences with the original GAN objective: $$ \min_G \max_D \mathbb{E}_{p_{data}}[\log D(x)] + \mathbb{E}_{p_z}[\log (1- D(G(z)))] \,, $$ where $D$ is a classifier. Note that in practice the WGAN paper calls $D$ a "critic" (it outputs an unbounded score rather than a probability) and updates it multiple times per generator update during training.
We will not go through the derivation but one approximation algorithm for optimizing the WGAN objective is to apply weight clipping to all the weights, i.e., enforce that their absolute value is smaller than some constant $c$. The full pseudo-algorithm can be found on slide 17 in these slides on WGAN or in the original paper.
Task: Implement WGAN with weight clipping. (1) Define `Discriminator_WGAN`, which is identical to the DCGAN discriminator except that it has no final Sigmoid layer (the critic outputs an unbounded score). (2) Set up RMSprop optimizers for both networks with learning rate `lr_rms` (which we set to 5e-4, which is larger than the rate in the paper but works better for our purposes). (3) Use the `torch.Tensor.clamp_()` function to clip the parameter values. You will need to do this for all parameters of the discriminator. See the algorithm for when to do this.
class Discriminator_WGAN(nn.Module):
def __init__(self, ngpu):
    """WGAN critic: the DCGAN discriminator without the final Sigmoid.

    The critic returns an unbounded scalar score per image instead of a
    probability, as required by the Wasserstein objective.
    """
    super(Discriminator_WGAN, self).__init__()
    self.ngpu = ngpu
    ################################ YOUR CODE ################################
    layers = [
        # input is (nc) x 32 x 32
        nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
        # state size: (ndf) x 16 x 16
        nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 2),
        nn.LeakyReLU(0.2, inplace=True),
        # state size: (ndf*2) x 8 x 8
        nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 4),
        nn.LeakyReLU(0.2, inplace=True),
        # state size: (ndf*4) x 4 x 4
        nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
        # No Sigmoid here: the critic's score is a raw real number.
    ]
    self.main = nn.Sequential(*layers)
    ########################### END YOUR CODE ################################
def forward(self, input):
return self.main(input)
# Instantiate the generator and the WGAN critic.
# NOTE(review): `initialize_net` is defined earlier in the notebook; it
# presumably applies `weights_init` and moves the model to `device`
# (wrapping for multi-GPU when ngpu > 1) -- verify against its definition.
netG = initialize_net(Generator, weights_init, device, ngpu)
netD = initialize_net(Discriminator_WGAN, weights_init, device, ngpu)
Generator(
(main): Sequential(
(0): ConvTranspose2d(100, 32, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): Tanh()
)
)
Discriminator_WGAN(
(main): Sequential(
(0): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(32, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
)
)
############################ YOUR CODE ############################
# Setup RMSprop optimizers for both netG and netD with given learning rate as `lr_rms`
######################## # END YOUR CODE ##########################
# Training Loop
# RMSprop is the optimizer prescribed by the WGAN paper for both networks.
optimizerD = optim.RMSprop(netD.parameters(), lr=lr_rms)
optimizerG = optim.RMSprop(netG.parameters(), lr=lr_rms)
# Lists to keep track of progress
img_list = []   # snapshots of G(fixed_noise) saved every 100 iterations
G_losses = []   # generator loss per outer iteration
D_losses = []   # critic loss per outer iteration
# Number of critic updates per single generator update.
n_critic = 5
# Weight-clipping threshold: all critic weights are kept within [-c, c].
c = 0.01
# Manual iterator so the loop below can draw batches across epoch boundaries.
dataloader_iter = iter(dataloader)
print("Starting Training Loop...")
num_iters = 1000
# Training loop for WGAN with weight clipping.
# Each outer iteration performs n_critic critic updates followed by one
# generator update, mirroring Algorithm 1 of the WGAN paper.
for iters in range(num_iters):
    ###########################################################################
    # (1) Train Discriminator more: minimize -(mean(D(real))-mean(D(fake)))
    ###########################################################################
    # Re-enable critic gradients (they are frozen during the G update below).
    for p in netD.parameters():
        p.requires_grad = True
    for idx_critic in range(n_critic):
        netD.zero_grad()
        # Fetch the next real batch, restarting the dataloader when exhausted.
        try:
            data = next(dataloader_iter)
        except StopIteration:
            dataloader_iter = iter(dataloader)
            data = next(dataloader_iter)
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        D_real = netD(real_cpu).view(-1)
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        # FIX: detach the fake batch so the critic loss does not backprop
        # through the generator -- G is not updated in this phase, so the
        # original `netD(fake)` wasted time and memory building G's graph.
        D_fake = netD(fake.detach()).view(-1)
        ############################ YOUR CODE ############################
        # Define your loss function for variable `D_loss`
        # Backpropagate the loss function and update the optimizer
        # Clip the D network parameters to be within -c and c by using `clamp_()` function
        # Note that if all weights are bounded, then the Lipschitz constant is bounded.
        # Critic loss: maximize mean(D(real)) - mean(D(fake)), i.e. minimize
        # its negation.
        D_loss = -(torch.mean(D_real) - torch.mean(D_fake))
        # Calculate the gradients for this batch
        D_loss.backward()
        # Update D
        optimizerD.step()
        # Weight clipping enforces a (crude) bound on the Lipschitz constant.
        for p in netD.parameters():
            p.data.clamp_(min=-c, max=c)
        ######################## # END YOUR CODE ##########################
    ###########################################################################
    # (2) Update G network: minimize -mean(D(fake))
    #     (one G step per n_critic critic steps)
    ###########################################################################
    # Freeze critic parameters so the G update does not accumulate grads in D.
    for p in netD.parameters():
        p.requires_grad = False
    netG.zero_grad()
    noise = torch.randn(b_size, nz, 1, 1, device=device)
    fake = netG(noise)
    D_fake = netD(fake).view(-1)
    ################################ YOUR CODE ################################
    # Define your loss function for variable `G_loss`
    # Backpropagate the loss function and update the optimizer
    # Generator loss: maximize mean(D(fake)) <=> minimize -mean(D(fake)).
    G_loss = -torch.mean(D_fake)
    # Calculate the gradients for this batch
    G_loss.backward()
    # Update G
    optimizerG.step()
    ############################# END YOUR CODE ##############################
    # Output training stats every 10 iterations.
    if iters % 10 == 0:
        print('[%4d/%4d] Loss_D: %6.4f Loss_G: %6.4f'
              % (iters, num_iters, D_loss.item(), G_loss.item()))
    # Save Losses for plotting later
    G_losses.append(G_loss.item())
    D_losses.append(D_loss.item())
    # Check how the generator is doing by saving G's output on fixed_noise
    if (iters % 100 == 0):
        with torch.no_grad():
            fake = netG(fixed_noise).detach().cpu()
        img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
Starting Training Loop... [ 0/1000] Loss_D: -0.0013 Loss_G: -0.0000 [ 10/1000] Loss_D: -0.0014 Loss_G: -0.0014 [ 20/1000] Loss_D: -0.0015 Loss_G: -0.0010 [ 30/1000] Loss_D: -0.0018 Loss_G: -0.0006 [ 40/1000] Loss_D: -0.0029 Loss_G: 0.0001 [ 50/1000] Loss_D: -0.0043 Loss_G: 0.0004 [ 60/1000] Loss_D: -0.0061 Loss_G: 0.0004 [ 70/1000] Loss_D: -0.0089 Loss_G: 0.0046 [ 80/1000] Loss_D: -0.0088 Loss_G: 0.0002 [ 90/1000] Loss_D: -0.0094 Loss_G: 0.0045 [ 100/1000] Loss_D: -0.0084 Loss_G: 0.0052 [ 110/1000] Loss_D: -0.0084 Loss_G: 0.0025 [ 120/1000] Loss_D: -0.0077 Loss_G: 0.0013 [ 130/1000] Loss_D: -0.0082 Loss_G: 0.0060 [ 140/1000] Loss_D: -0.0083 Loss_G: 0.0070 [ 150/1000] Loss_D: -0.0077 Loss_G: 0.0068 [ 160/1000] Loss_D: -0.0072 Loss_G: 0.0056 [ 170/1000] Loss_D: -0.0069 Loss_G: 0.0072 [ 180/1000] Loss_D: -0.0066 Loss_G: 0.0036 [ 190/1000] Loss_D: -0.0048 Loss_G: 0.0001 [ 200/1000] Loss_D: -0.0062 Loss_G: 0.0089 [ 210/1000] Loss_D: -0.0056 Loss_G: 0.0002 [ 220/1000] Loss_D: -0.0057 Loss_G: 0.0058 [ 230/1000] Loss_D: -0.0065 Loss_G: 0.0085 [ 240/1000] Loss_D: -0.0054 Loss_G: 0.0025 [ 250/1000] Loss_D: -0.0050 Loss_G: 0.0006 [ 260/1000] Loss_D: -0.0055 Loss_G: 0.0063 [ 270/1000] Loss_D: -0.0044 Loss_G: 0.0036 [ 280/1000] Loss_D: -0.0036 Loss_G: 0.0011 [ 290/1000] Loss_D: -0.0040 Loss_G: -0.0012 [ 300/1000] Loss_D: -0.0041 Loss_G: 0.0014 [ 310/1000] Loss_D: -0.0042 Loss_G: -0.0019 [ 320/1000] Loss_D: -0.0028 Loss_G: -0.0004 [ 330/1000] Loss_D: -0.0037 Loss_G: -0.0011 [ 340/1000] Loss_D: -0.0035 Loss_G: -0.0002 [ 350/1000] Loss_D: -0.0029 Loss_G: 0.0054 [ 360/1000] Loss_D: -0.0033 Loss_G: 0.0056 [ 370/1000] Loss_D: -0.0032 Loss_G: 0.0025 [ 380/1000] Loss_D: -0.0041 Loss_G: 0.0077 [ 390/1000] Loss_D: -0.0035 Loss_G: 0.0018 [ 400/1000] Loss_D: -0.0033 Loss_G: 0.0027 [ 410/1000] Loss_D: -0.0035 Loss_G: 0.0101 [ 420/1000] Loss_D: -0.0025 Loss_G: -0.0029 [ 430/1000] Loss_D: -0.0030 Loss_G: 0.0015 [ 440/1000] Loss_D: -0.0029 Loss_G: 0.0016 [ 450/1000] Loss_D: -0.0030 Loss_G: 
0.0066 [ 460/1000] Loss_D: -0.0018 Loss_G: 0.0065 [ 470/1000] Loss_D: -0.0024 Loss_G: 0.0010 [ 480/1000] Loss_D: -0.0036 Loss_G: 0.0042 [ 490/1000] Loss_D: -0.0027 Loss_G: -0.0003 [ 500/1000] Loss_D: -0.0028 Loss_G: 0.0012 [ 510/1000] Loss_D: -0.0027 Loss_G: 0.0043 [ 520/1000] Loss_D: -0.0028 Loss_G: 0.0028 [ 530/1000] Loss_D: -0.0025 Loss_G: -0.0001 [ 540/1000] Loss_D: -0.0027 Loss_G: -0.0001 [ 550/1000] Loss_D: -0.0027 Loss_G: 0.0028 [ 560/1000] Loss_D: -0.0028 Loss_G: -0.0026 [ 570/1000] Loss_D: -0.0029 Loss_G: 0.0025 [ 580/1000] Loss_D: -0.0026 Loss_G: 0.0010 [ 590/1000] Loss_D: -0.0021 Loss_G: -0.0007 [ 600/1000] Loss_D: -0.0024 Loss_G: 0.0010 [ 610/1000] Loss_D: -0.0021 Loss_G: -0.0008 [ 620/1000] Loss_D: -0.0027 Loss_G: 0.0020 [ 630/1000] Loss_D: -0.0028 Loss_G: 0.0082 [ 640/1000] Loss_D: -0.0023 Loss_G: 0.0053 [ 650/1000] Loss_D: -0.0025 Loss_G: 0.0009 [ 660/1000] Loss_D: -0.0024 Loss_G: -0.0041 [ 670/1000] Loss_D: -0.0025 Loss_G: 0.0014 [ 680/1000] Loss_D: -0.0027 Loss_G: -0.0020 [ 690/1000] Loss_D: -0.0019 Loss_G: 0.0026 [ 700/1000] Loss_D: -0.0020 Loss_G: 0.0057 [ 710/1000] Loss_D: -0.0023 Loss_G: 0.0024 [ 720/1000] Loss_D: -0.0025 Loss_G: 0.0037 [ 730/1000] Loss_D: -0.0020 Loss_G: 0.0042 [ 740/1000] Loss_D: -0.0024 Loss_G: 0.0062 [ 750/1000] Loss_D: -0.0024 Loss_G: 0.0022 [ 760/1000] Loss_D: -0.0025 Loss_G: 0.0026 [ 770/1000] Loss_D: -0.0018 Loss_G: 0.0001 [ 780/1000] Loss_D: -0.0021 Loss_G: 0.0032 [ 790/1000] Loss_D: -0.0024 Loss_G: -0.0005 [ 800/1000] Loss_D: -0.0019 Loss_G: -0.0013 [ 810/1000] Loss_D: -0.0023 Loss_G: 0.0020 [ 820/1000] Loss_D: -0.0021 Loss_G: 0.0008 [ 830/1000] Loss_D: -0.0023 Loss_G: -0.0027 [ 840/1000] Loss_D: -0.0020 Loss_G: -0.0008 [ 850/1000] Loss_D: -0.0024 Loss_G: 0.0029 [ 860/1000] Loss_D: -0.0016 Loss_G: 0.0020 [ 870/1000] Loss_D: -0.0016 Loss_G: 0.0039 [ 880/1000] Loss_D: -0.0023 Loss_G: -0.0063 [ 890/1000] Loss_D: -0.0021 Loss_G: -0.0056 [ 900/1000] Loss_D: -0.0017 Loss_G: 0.0025 [ 910/1000] Loss_D: -0.0019 Loss_G: 0.0001 
[ 920/1000] Loss_D: -0.0017 Loss_G: 0.0014 [ 930/1000] Loss_D: -0.0022 Loss_G: 0.0010 [ 940/1000] Loss_D: -0.0020 Loss_G: -0.0018 [ 950/1000] Loss_D: -0.0016 Loss_G: 0.0006 [ 960/1000] Loss_D: -0.0024 Loss_G: -0.0019 [ 970/1000] Loss_D: -0.0018 Loss_G: 0.0028 [ 980/1000] Loss_D: -0.0022 Loss_G: 0.0065 [ 990/1000] Loss_D: -0.0022 Loss_G: -0.0038
# plot the loss for generator and discriminator
# (plot_GAN_loss is a helper defined earlier in the notebook)
plot_GAN_loss([G_losses, D_losses], ["G", "D"])
# Grab a batch of real images from the dataloader
# and display it next to the saved generator samples in img_list.
plot_real_fake_images(next(iter(dataloader)), img_list)
Use slide 19 in the lecture notes on WGAN to implement the WGAN-GP algorithm.
To compute the gradient penalty you will need `torch.autograd.grad`. You will need to set its `outputs`, `inputs`, and `grad_outputs` arguments, and pass `create_graph=True` and `retain_graph=True` (because we want to backprop through this gradient calculation for the final objective). `grad_norm = torch.sqrt((grad**2).sum(1) + 1e-14)` is a simple way to compute the norm. Train the model with the modified networks and visualize the results.
# Setup networks for WGAN-GP
netG = initialize_net(Generator, weights_init, device, ngpu)
netD = initialize_net(Discriminator_WGAN, weights_init, device, ngpu)
# Setup Adam optimizers for both G and D
# WGAN-GP needs no weight clipping, so Adam is safe to use here;
# betas=(0.5, 0.9) follows the WGAN-GP paper's recommendation.
optimizerD = optim.Adam(netD.parameters(), lr=5e-4, betas=(0.5, 0.9))
optimizerG = optim.Adam(netG.parameters(), lr=5e-4, betas=(0.5, 0.9))
# Training Loop
# Lists to keep track of progress
img_list = []   # snapshots of G(fixed_noise) saved every 100 iterations
G_losses = []   # generator loss per outer iteration
D_losses = []   # critic loss per outer iteration
# Number of critic updates per single generator update.
n_critic = 5
# Manual iterator so the loop below can draw batches across epoch boundaries.
dataloader_iter = iter(dataloader)
print("Starting Training Loop...")
num_iters = 1000
# Training-loop skeleton for WGAN-GP: the gradient penalty replaces weight
# clipping; the student fills in the penalty and loss in the YOUR CODE blocks.
for iters in range(num_iters):
    ###########################################################################
    # (1) Train Discriminator more: minimize -(mean(D(real))-mean(D(fake)))+GP
    ###########################################################################
    # Re-enable critic gradients (they are frozen during the G update below).
    for p in netD.parameters():
        p.requires_grad = True
    for idx_critic in range(n_critic):
        netD.zero_grad()
        # Fetch the next real batch, restarting the dataloader when exhausted.
        try:
            data = next(dataloader_iter)
        except StopIteration:
            dataloader_iter = iter(dataloader)
            # BUG FIX: re-fetch after restarting the iterator. The original
            # code left `data` holding the stale batch from the previous
            # critic step (or raised NameError if exhaustion hit first).
            data = next(dataloader_iter)
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        D_real = netD(real_cpu).view(-1)
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        D_fake = netD(fake).view(-1)
        ############################ YOUR CODE ############################
        # Compute the gradient penalty term
        # Define your loss function for variable `D_loss`
        # Backpropagate the loss function and update the optimizer
        ######################## # END YOUR CODE ##########################
    ###########################################################################
    # (2) Update G network: minimize -mean(D(fake))
    #     (one G step per n_critic critic steps)
    ###########################################################################
    # Freeze critic parameters so the G update does not accumulate grads in D.
    for p in netD.parameters():
        p.requires_grad = False
    netG.zero_grad()
    noise = torch.randn(b_size, nz, 1, 1, device=device)
    fake = netG(noise)
    D_fake = netD(fake).view(-1)
    ################################ YOUR CODE ################################
    # Define your loss function for variable `G_loss`
    # Backpropagate the loss function and update the optimizer
    ############################# END YOUR CODE ##############################
    # Output training stats (requires D_loss / G_loss defined in YOUR CODE).
    if iters % 10 == 0:
        print('[%4d/%4d] Loss_D: %6.4f Loss_G: %6.4f'
              % (iters, num_iters, D_loss.item(), G_loss.item()))
    # Save Losses for plotting later
    G_losses.append(G_loss.item())
    D_losses.append(D_loss.item())
    # Check how the generator is doing by saving G's output on fixed_noise
    if (iters % 100 == 0):
        with torch.no_grad():
            fake = netG(fixed_noise).detach().cpu()
        img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
# plot the loss for generator and discriminator
# (plot_GAN_loss is a helper defined earlier in the notebook)
plot_GAN_loss([G_losses, D_losses], ["G", "D"])
# Grab a batch of real images from the dataloader
# and display it next to the saved generator samples in img_list.
plot_real_fake_images(next(iter(dataloader)), img_list)